{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine Tuning using ReFT Adapters\n",
    "\n",
    "In this tutorial, we demonstrate how to fine-tune a language model using ReFT, the method introduced in [ReFT: Representation Finetuning for Language Models](https://arxiv.org/abs/2404.03592).\n",
    "\n",
    "We will use a lightweight encoder model and focus on fine-tuning via ReFT adapters rather than traditional full-model fine-tuning.\n",
    "\n",
    "For more information on ReFT, see the authors' GitHub [repository](https://github.com/stanfordnlp/pyreft)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installation\n",
    "\n",
    "Before we can get started, we need to ensure the proper packages are installed. Here's a breakdown of what we need:\n",
    "\n",
    "- `adapters` and `accelerate` for efficient fine-tuning and training optimization\n",
    "- `evaluate` for metric computation and model evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -qq \"evaluate>=0.3.0\"\n",
    "!pip install -qq -U adapters accelerate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset\n",
    "\n",
    "In this tutorial, we will use the `nyu-mll/glue` dataset, which hosts the GLUE benchmark: a collection of language datasets used to benchmark and analyze language understanding systems. We will use the Multi-Genre Natural Language Inference (MNLI) task and evaluate on its matched validation split. The model is given a premise and a hypothesis and must classify how the hypothesis relates to the premise. There are three classes: \"entailment\", \"neutral\", and \"contradiction\".\n",
    "\n",
    "For comparison purposes, we will fine-tune `roberta-base`, as in the original paper, but you can swap in any model you prefer. We also use hyperparameters similar to those in the paper for both model training and the ReFT adapter, but train for fewer epochs to reduce compute and time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset(\"nyu-mll/glue\", \"mnli\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])\n"
     ]
    }
   ],
   "source": [
    "print(ds.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the `train` split for training and the `validation_matched` split for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = ds[\"train\"]\n",
    "eval_dataset = ds[\"validation_matched\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Load the tokenizer that matches the base model\n",
    "model_name_or_path = \"roberta-base\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Preprocessing\n",
    "\n",
    "We tokenize the premise and hypothesis of each example together as a pair, padding to the maximum length, and remove the `idx` column since it is not needed for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(example):\n",
    "    return tokenizer(example['premise'], example['hypothesis'], return_tensors = \"pt\", truncation=True, padding='max_length')\n",
    "\n",
    "train_dataset = train_dataset.map(preprocess_function, batched=True)\n",
    "eval_dataset = eval_dataset.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = train_dataset.remove_columns([\"idx\"])\n",
    "eval_dataset = eval_dataset.remove_columns([\"idx\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import default_data_collator\n",
    "\n",
    "data_collator = default_data_collator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model and Adapter initialization\n",
    "\n",
    "We load the `roberta-base` model along with the `LoReftConfig`. A ReFT config can be initialized in a single line of code and added to the base model with the `add_adapter` function. On top of that, we add a classification head with 3 labels under the same name as the adapter.\n",
    "\n",
    "For more information on the ReFT implementation in `adapters`, including the supported configs and their parameters, see the docs page: https://docs.adapterhub.ml/methods.html#reft\n",
    "\n",
    "Don't forget to activate the adapter that you want to train."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaAdapterModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['heads.default.3.bias', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from adapters import AutoAdapterModel, LoReftConfig\n",
    "model = AutoAdapterModel.from_pretrained(model_name_or_path)\n",
    "\n",
    "config = LoReftConfig(r = 1, prefix_positions = 1, dropout = 0.05)\n",
    "model.add_adapter(\"loreft_adapter\", config=config)\n",
    "model.add_classification_head(\"loreft_adapter\", num_labels=3)\n",
    "model.train_adapter(\"loreft_adapter\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "================================================================================\n",
      "Name                     Architecture         #Param      %Param  Active   Train\n",
      "--------------------------------------------------------------------------------\n",
      "loreft_adapter           reft                 36,888       0.030       1       1\n",
      "--------------------------------------------------------------------------------\n",
      "Full model                               124,645,632     100.000               0\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "print(model.adapter_summary())"
   ]
  },
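  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter count in the summary can be roughly reproduced by hand. The sketch below assumes the LoReFT intervention from the paper, `h + R^T (W h + b - R h)`, with a rank-`r` projection `R`, a linear map `W`, and a bias `b`. It further assumes two such intervention units per layer (one for prefix positions and one for suffix positions) across the 12 layers of `roberta-base`. These structural assumptions are inferred from the reported numbers rather than read from the library source, so treat the breakdown as illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back-of-the-envelope LoReFT parameter count (assumed structure, see the note above)\n",
    "hidden_size = 768      # roberta-base hidden dimension\n",
    "num_layers = 12        # roberta-base transformer layers\n",
    "r = 1                  # rank passed to LoReftConfig\n",
    "units_per_layer = 2    # assumption: separate prefix and suffix units\n",
    "\n",
    "# R (r x d) + W (r x d) + b (r)\n",
    "params_per_unit = r * hidden_size + r * hidden_size + r\n",
    "adapter_params = num_layers * units_per_layer * params_per_unit\n",
    "\n",
    "full_model_params = 124_645_632  # 'Full model' row of the adapter summary\n",
    "print(f'{adapter_params:,} adapter parameters')\n",
    "print(f'{100 * adapter_params / full_model_params:.3f}% of the full model')"
   ]
  },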
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation\n",
    "\n",
    "We'll use accuracy as our main metric to evaluate the performance of the ReFT model on the `mnli` dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import numpy as np\n",
    "accuracy_metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return accuracy_metric.compute(predictions=predictions, references=labels)\n"
   ]
  },
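  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the metric concrete, here is a small sketch on made-up logits (the values are purely illustrative, not model outputs). It performs the same argmax-then-compare step as `compute_metrics`, but uses plain NumPy instead of the `evaluate` metric so it can run on its own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical logits for 4 examples over the 3 MNLI classes\n",
    "toy_logits = np.array([\n",
    "    [2.0, 0.1, -1.0],\n",
    "    [0.2, 1.5, 0.3],\n",
    "    [-0.5, 0.0, 2.2],\n",
    "    [1.0, 0.9, 0.8],\n",
    "])\n",
    "toy_labels = np.array([0, 1, 2, 1])\n",
    "\n",
    "# Pick the highest-scoring class per row, then compare against the labels\n",
    "toy_predictions = np.argmax(toy_logits, axis=-1)\n",
    "toy_accuracy = float((toy_predictions == toy_labels).mean())\n",
    "print(toy_predictions, toy_accuracy)"
   ]
  },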
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training\n",
    "\n",
    "Now we are ready to train our model. We use the same hyperparameters that the original paper used for `mnli`, except for the number of epochs. For this tutorial, we train for only 2 epochs for demonstration purposes; in practice, more epochs are usually expected. Feel free to tweak the hyperparameters to your own needs!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from adapters import AdapterTrainer\n",
    "from transformers import TrainingArguments\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',\n",
    "    eval_strategy='epoch',\n",
    "    learning_rate=6e-4,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    num_train_epochs=2,\n",
    "    weight_decay=0.00,\n",
    "    optim = \"adamw_hf\",\n",
    "    lr_scheduler_type = \"linear\",\n",
    "    warmup_ratio= 6e-2,\n",
    ")"
   ]
  },
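  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on these settings, the number of optimizer steps can be worked out ahead of time. The sketch below assumes a single device, no gradient accumulation, and the 392,702 examples of the MNLI training split. The ceiling used for the warmup steps is an assumption about how `warmup_ratio` is turned into a step count."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_train_examples = 392_702   # size of the MNLI train split\n",
    "batch_size = 32                # per_device_train_batch_size, single device assumed\n",
    "num_epochs = 2\n",
    "warmup_ratio = 6e-2\n",
    "\n",
    "steps_per_epoch = math.ceil(num_train_examples / batch_size)\n",
    "total_steps = steps_per_epoch * num_epochs\n",
    "warmup_steps = math.ceil(total_steps * warmup_ratio)  # rounding is an assumption\n",
    "\n",
    "print(f'{steps_per_epoch} steps per epoch, {total_steps} total, about {warmup_steps} warmup steps')"
   ]
  },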
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='24544' max='24544' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [24544/24544 3:40:48, Epoch 2/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.603000</td>\n",
       "      <td>0.508186</td>\n",
       "      <td>0.795619</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.559800</td>\n",
       "      <td>0.475799</td>\n",
       "      <td>0.812124</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=24544, training_loss=0.6184384013092658, metrics={'train_runtime': 13249.0278, 'train_samples_per_second': 59.28, 'train_steps_per_second': 1.853, 'total_flos': 2.0971423089834394e+17, 'train_loss': 0.6184384013092658, 'epoch': 2.0})"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer = AdapterTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics = compute_metrics\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By training only 0.030% of the model's parameters, we reached about 0.81 accuracy on the matched validation set after just two epochs! Since only the adapter weights are updated, training is also typically cheaper than fully fine-tuning the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference\n",
    "\n",
    "With our ReFT-tuned model, we can now run inference on new text!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Label ids follow the order of the GLUE MNLI label feature\n",
    "mapping = {\n",
    "    0: \"entailment\",\n",
    "    1: \"neutral\",\n",
    "    2: \"contradiction\"\n",
    "}\n",
    "\n",
    "def infer_text(text):\n",
    "    # Tokenize and move the inputs to the same device as the model\n",
    "    inputs = tokenizer(text, truncation=True, padding='max_length', return_tensors=\"pt\").to(model.device)\n",
    "    model.eval()  # disable dropout for inference\n",
    "    with torch.no_grad():\n",
    "        outputs = model(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"])\n",
    "    logits = outputs[\"logits\"]\n",
    "    # Softmax over the class dimension\n",
    "    p = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()\n",
    "    prediction = np.argmax(p, axis=-1)\n",
    "    print(f\"The classification of this text is: {mapping[prediction[0]]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = [\"I like Apple Pie. I don't like Apple Pie\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The classification of this text is: contradiction\n"
     ]
    }
   ],
   "source": [
    "infer_text(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving the adapter model\n",
    "\n",
    "If you would like to save your adapter locally or push it to the Hugging Face Hub, you can do so with the code below. Make sure to sign in before pushing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_adapter(\"./reft_adapter\", \"loreft_adapter\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.push_adapter_to_hub(\n",
    "    \"roberta-base-reft-adapter\",\n",
    "    \"loreft_adapter\",\n",
    "    datasets_tag=\"mnli\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "adapter_hub",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
